import jax.numpy as jnp
import jax.nn as nn
import numpy as np

from src.utils import *
from flax import linen as nn
from typing import NamedTuple, Optional,Any


def eluplus1(x):
    """Kernel feature map phi(x) = ELU(x) + 1 for linear attention.

    ELU(x) >= -1 everywhere, so the shifted result is non-negative. That keeps
    the kernelised attention weights (and their normaliser) from going
    negative.

    Args:
        x: array of pre-activations.

    Returns:
        Array with the same shape as ``x``.
    """
    shifted_elu = nn.elu(x)
    return shifted_elu + 1

class PosEmbLayer(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence of embeddings.

    Row ``t`` (0-based) of the sequence receives the encoding of position
    ``t + 1``: ``sin(pos * w_i)`` in the even columns and ``cos(pos * w_i)`` in
    the odd columns, with ``w_i = 10000 ** (-2 * i / d_model)``.
    """
    # Embedding width; must equal inputs.shape[-1].
    d_model: int

    @nn.compact
    def __call__(self, inputs):
        """Return ``inputs`` plus the position encoding.

        Args:
            inputs: array whose last two axes are (T, d_model), e.g. the
                T x d_model sequence u_t^{i-1}. Any leading axes are broadcast
                against the (T, d_model) encoding.

        Returns:
            ``inputs + pos_emb`` (broadcast shape of the two).
        """
        div_term = jnp.exp(jnp.arange(0, self.d_model, 2)
                           * (-jnp.log(10000.0) / self.d_model))
        truncation = inputs.shape[-2]
        # Positions are 1-based.
        position = 1 + jnp.arange(0, truncation).reshape(-1, 1)
        angles = position * div_term  # (T, ceil(d_model / 2))

        pos_emb = jnp.zeros((truncation, self.d_model))
        pos_emb = pos_emb.at[:, 0::2].set(jnp.sin(angles))
        # There are only d_model // 2 odd columns. For odd d_model that is one
        # fewer than the number of frequencies, so drop the last frequency
        # here instead of failing with a shape mismatch.
        pos_emb = pos_emb.at[:, 1::2].set(jnp.cos(angles[:, : self.d_model // 2]))
        return inputs + pos_emb

class LayerEmbLayer(nn.Module):
    """Add a fixed sinusoidal encoding of a layer index to every position.

    One d_model-wide vector is added to every row of the input:
    ``sin(layer_id * w_i)`` in the even columns and ``cos(layer_id * w_i)`` in
    the odd columns, with ``w_i = 10000 ** (-2 * i / d_model)``. This lets
    weight-shared layers tell which depth they are applied at.
    """
    # Embedding width; must equal inputs.shape[-1].
    d_model: int

    @nn.compact
    def __call__(self, inputs, layer_id):
        """Return ``inputs`` plus the encoding of ``layer_id``.

        Args:
            inputs: array whose last axis has size d_model, e.g. the
                T x d_model sequence u_t^{i-1}. All leading axes broadcast.
            layer_id: layer index (scalar) to encode.

        Returns:
            ``inputs + layer_emb`` with the shape of ``inputs``.
        """
        div_term = jnp.exp(jnp.arange(0, self.d_model, 2)
                           * (-jnp.log(10000.0) / self.d_model))
        angles = layer_id * div_term  # (ceil(d_model / 2),)

        # The encoding is identical for every row, so build a single row and
        # let broadcasting add it to each position.
        layer_emb = jnp.zeros((self.d_model,))
        layer_emb = layer_emb.at[0::2].set(jnp.sin(angles))
        # Only d_model // 2 odd columns exist; slice so odd d_model also works.
        layer_emb = layer_emb.at[1::2].set(jnp.cos(angles[: self.d_model // 2]))
        return inputs + layer_emb

class TSOPState(NamedTuple):
    """Prefix-summed linear-attention state (as built by CumulativeSumOuterProduct).

    When produced by CumulativeSumOuterProduct on a (T, d_model) input, the
    shapes are c: (T, d_model, d_model) and s: (T, d_model). When produced
    under a head-wise ``nn.vmap``, there is an extra leading head axis.
    """
    # c[t] = sum over steps tau <= t of outer(value_tau, phi(key_tau)).
    c: jnp.ndarray
    # s[t] = sum over steps tau <= t of phi(key_tau); the attention normaliser.
    s: jnp.ndarray
        
class CumulativeSumOuterProduct(nn.Module):
    """Causal linear-attention memory for a single head.

    Keys are ``kernel_phi(Dense(inputs))`` and values are ``Dense(inputs)``,
    both (T, d_model). The returned state holds running sums over time:
        c[t] = sum_{tau <= t} v_tau phi(k_tau)^T   -> (T, d_model, d_model)
        s[t] = sum_{tau <= t} phi(k_tau)           -> (T, d_model)
    """
    d_model: int
    kernel_phi: Any = eluplus1

    @nn.compact
    def __call__(self, inputs):
        """Build the prefix-summed key/value state for ``inputs`` (T x d_model)."""
        # The key Dense is created before the value Dense. Flax numbers
        # unnamed submodules in creation order, so this order fixes which
        # parameters are Dense_0 / Dense_1.
        phi_keys = self.kernel_phi(nn.Dense(self.d_model)(inputs))
        projected_values = nn.Dense(self.d_model)(inputs)

        def column_times_row(v, k):
            # Outer product of one value vector with one key vector.
            return jnp.dot(v.reshape(-1, 1), k.reshape(1, -1))

        # NOTE(review): `jax` is not imported by name in this file; presumably
        # it arrives via `from src.utils import *` -- confirm.
        per_step_outer = jax.vmap(column_times_row)(projected_values, phi_keys)
        return TSOPState(
            c=jnp.cumsum(per_step_outer, axis=0),
            s=jnp.cumsum(phi_keys, axis=0),
        )


class LinearTransformerEncoder(nn.Module):
    """One encoder layer built on multi-head causal linear attention.

    For head h and step t the attention read-out is
        c_h[t] @ q_h[t] / (s_h[t] . q_h[t]),
    where (c_h, s_h) are the prefix sums from CumulativeSumOuterProduct.
    Because they are prefix sums, step t attends only to steps <= t.
    q_h = kernel_phi(Dense(x)). The head outputs are concatenated, projected
    back to d_model, and passed through residual + LayerNorm. A
    Dense-ReLU-Dense feed-forward sub-layer with its own residual + LayerNorm
    follows.
    """
    # Width of the attention / residual stream.
    d_model:int
    # Hidden width of the feed-forward sub-layer.
    d_ffc:int
    # Number of attention heads.
    n_heads:int 
    # Feature map applied to keys and queries.
    kernel_phi:Any=eluplus1
        
    @nn.compact
    def __call__(self,inputs, use_dense:bool=False,pos_emb_type:Optional[str]=None):
        """Apply the layer to one sequence.

        NOTE(review): flax names the unnamed Dense/LayerNorm submodules below
        by creation order. Reordering these statements changes the parameter
        tree.

        Args:
            inputs: 2-D array (T, features); its shape is unpacked below.
                Unless ``use_dense`` is set it must be d_model wide, because
                it feeds the residual connection directly.
            use_dense: if True, first project ``inputs`` to d_model with a
                Dense submodule named 'emb_layer'.
            pos_emb_type: 'absolute' adds PosEmbLayer's sinusoidal encoding
                (submodule 'pos_emb'); any other value, including None, adds
                nothing.

        Returns:
            Array of shape (T, d_model).
        """
        # inputs must be 2-D (T x features).
        truncation,_=inputs.shape
        # inputs u_t^{i-1} shape T X d_model
        def attention_func(c_t,s_t,Q_t):
            # Per-step read-out for one head:
            #   (c_t[t] @ Q_t[t]) / (s_t[t] . Q_t[t]).
            # NOTE(review): no epsilon in the division. This relies on
            # kernel_phi making s_t . Q_t > 0, which holds for the default
            # eluplus1 unless exp underflows.
            norm_fun=jax.vmap(lambda s,q: jnp.dot(s,q))
            attention_fun=jax.vmap(lambda c,q:jnp.dot(c,q))
            norm=norm_fun(s_t,Q_t)
            attention_sum=attention_fun(c_t,Q_t)
            return attention_sum/norm.reshape(-1,1)

        # Optional input projection to d_model.
        if use_dense:
            inputs_enc=nn.Dense(self.d_model,name='emb_layer')(inputs)
        else:
            inputs_enc=inputs
        # Optional absolute sinusoidal position embedding.
        if pos_emb_type=='absolute':
            inputs_embed=PosEmbLayer(self.d_model,name='pos_emb')(inputs_enc)
        else:
            inputs_embed=inputs_enc
        # Copy the sequence once per head: (n_heads, T, d_model).
        inputs_repeat=jnp.repeat(jnp.expand_dims(inputs_embed,axis=0),repeats=self.n_heads,axis=0)
        # Key/value prefix-sum state per head. The vmap over axis 0 with
        # variable_axes={'params': 0} keeps separate parameters per head.
        csop_mh=nn.vmap(CumulativeSumOuterProduct,in_axes=0, out_axes=0,
                                variable_axes={'params': 0},
                                split_rngs={'params': True})(d_model=self.d_model,kernel_phi=self.kernel_phi,name='csop')
        state=csop_mh(inputs_repeat)
        # Per-head queries: kernel_phi(Dense(x)).
        # NOTE(review): the Dense is constructed in this module's scope before
        # being handed to the vmapped Sequential. Confirm in the param tree
        # that it is stacked per head as intended.
        query_layer_mh=nn.vmap(nn.Sequential,in_axes=0, out_axes=0,
                                variable_axes={'params': 0},
                                split_rngs={'params': True})([nn.Dense(self.d_model),self.kernel_phi],name='query')
        Q_t=query_layer_mh(inputs_repeat)
        # NOTE(review): `jax` is not imported by name in this file; presumably
        # it comes from `from src.utils import *` -- confirm.
        attn_out=jax.vmap(attention_func)(state.c,state.s,Q_t) #n_headsXTXd_model
        # Concatenate the heads along the feature axis.
        attn_out=jnp.transpose(attn_out,(1,0,2)).reshape(truncation,-1) #TXd_model*n_heads
        attn_out=nn.Dense(self.d_model)(attn_out) #TXd_model
        # Residual connection around attention, then LayerNorm.
        attn_out=nn.LayerNorm()(attn_out+inputs_embed)
        # Feed-forward sub-layer (Dense -> ReLU -> Dense) with residual + LayerNorm.
        ffc=nn.Sequential([nn.Dense(self.d_ffc),jax.nn.relu,nn.Dense(self.d_model)])
        out=nn.LayerNorm()(ffc(attn_out)+attn_out)
        return out

class UniversalLinearTransformer(nn.Module):
    """Weight-shared ("universal") stack of linear-transformer encoder layers.

    The raw inputs are projected to d_model. A single LinearTransformerEncoder
    instance is then applied ``n_layers`` times. Before each application the
    running sequence receives a sinusoidal position embedding and, on top of
    it, a sinusoidal embedding of the current layer index.
    """
    # Model width; the inputs are projected to it first.
    d_model: int
    # Hidden width of the encoder's feed-forward sub-layer.
    d_ffc: int
    # Number of attention heads in the encoder.
    n_heads: int
    # Presumably the context window length T; not read by this module.
    truncation: int
    # Number of times the shared encoder is applied.
    n_layers: int
    # Feature map used by the linear attention.
    kernel_phi: Any = eluplus1

    @nn.compact
    def __call__(self, inputs):
        """Encode the last T inputs.

        Args:
            inputs: array of shape (T, input_dim), one row per time step x_t.
                The encoder requires a 2-D sequence.

        Returns:
            Array of shape (T, d_model): the output of the final encoder
            application. With ``n_layers == 0`` this is the input projection.
        """
        # Keep this construction order: flax names these unnamed submodules
        # by type and creation order, which defines the parameter tree.
        hidden = nn.Dense(self.d_model)(inputs)
        posemblayer = PosEmbLayer(self.d_model)
        layeremblayer = LayerEmbLayer(self.d_model)
        encoder = LinearTransformerEncoder(self.d_model, self.d_ffc, self.n_heads,
                                           kernel_phi=self.kernel_phi)
        for layer in range(self.n_layers):
            # Position embedding first, then the layer embedding is added to
            # the already position-embedded sequence (both are applied).
            inputs_emb = posemblayer(hidden)
            inputs_emb = layeremblayer(inputs_emb, layer)
            hidden = encoder(inputs_emb)
        return hidden


class LinearTransformer(nn.Module):
    """Stack of independently parameterised linear-transformer encoder layers.

    Each depth gets its own LinearTransformerEncoder, named 'layer0',
    'layer1', .... Only the first layer projects the raw inputs to d_model
    (``use_dense``) and applies ``pos_emb_type``. Deeper layers consume the
    previous layer's output unchanged.
    """
    # Model width.
    d_model: int
    # Hidden width of each encoder's feed-forward sub-layer.
    d_ffc: int
    # Number of attention heads per encoder.
    n_heads: int
    # Presumably the context window length T; not read by this module.
    truncation: int
    # Number of encoder layers; with 0 the inputs are returned unchanged.
    n_layers: int
    # Passed to the first encoder; only 'absolute' adds a position embedding.
    pos_emb_type: str = 'absolute'
    # Feature map used by the linear attention.
    kernel_phi: Any = eluplus1

    @nn.compact
    def __call__(self, inputs):
        """Run the encoder stack over the last T inputs.

        Args:
            inputs: array of shape (T, input_dim), one row per time step x_t.

        Returns:
            Output of the last encoder, shape (T, d_model). If ``n_layers``
            is 0, ``inputs`` is returned unchanged.
        """
        hidden = inputs
        for depth in range(self.n_layers):
            is_first = depth == 0
            layer_block = LinearTransformerEncoder(
                d_model=self.d_model,
                d_ffc=self.d_ffc,
                n_heads=self.n_heads,
                kernel_phi=self.kernel_phi,
                name='layer%d' % depth,
            )
            hidden = layer_block(
                hidden,
                use_dense=is_first,
                pos_emb_type=self.pos_emb_type if is_first else None,
            )
        return hidden